[Fix] Mixed Precision (float16) numerical instability in GroupNormalization with small epsilon #22589
ChiragSW wants to merge 6 commits into keras-team:master
Conversation
Code Review
This pull request updates the GroupNormalization layer to improve numerical stability during mixed precision training. It disables automatic casting for the layer and its weights (gamma and beta) and adds an explicit cast to compute_dtype at the end of the call method. New test cases are included to ensure that large input values do not result in NaNs when running within a float16 autocast scope. I have no feedback to provide.
Codecov Report ❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##           master   #22589      +/-   ##
===========================================
- Coverage   83.28%   24.73%   -58.56%
===========================================
  Files         596      596
  Lines       68089    68254     +165
  Branches    10607    10668      +61
===========================================
- Hits        56711    16884   -39827
- Misses       8634    50456   +41822
+ Partials     2744      914    -1830
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
@hertschuh please review
```python
):
    super().__init__(**kwargs)
    self.supports_masking = True
    self.autocast = False
```
We should not hardcode `self.autocast = False`. The fix is indeed to do this:

```python
keras.layers.GroupNormalization(groups=8, epsilon=1e-12, autocast=False)
```

But this should be controlled by users, not hardcoded.
The contract of autocast is to accept lower precision to improve speed, and that option should remain open to people who want it.
Now, we could print a warning if the epsilon is lower than the precision, because this is not achievable.
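The suggested warning could be sketched as follows. This is a standalone illustration, not the Keras implementation: the helper name `warn_if_epsilon_too_small` is hypothetical, and `numpy.finfo` stands in for the dtype-precision lookup.

```python
import warnings

import numpy as np


def warn_if_epsilon_too_small(epsilon, compute_dtype):
    """Hypothetical helper: warn when `epsilon` is below the precision
    of `compute_dtype`, so adding it to the variance would do nothing."""
    try:
        eps_limit = float(np.finfo(compute_dtype).eps)
    except (TypeError, ValueError):
        return  # non-float dtype: nothing to check
    if 0 < epsilon < eps_limit:
        warnings.warn(
            f"`epsilon`={epsilon} is below the precision of "
            f"{compute_dtype} (eps={eps_limit}); consider increasing "
            "`epsilon` or passing `autocast=False`.",
            stacklevel=2,
        )
```

In the layer, a check like this would run once, after the dtype policy is known.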
The suggested changes are done, please review @hertschuh
```python
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
autocast=True,
```
Please revert all the changes related to autocast. We don't want any behavior change, for backwards compatibility. It should stay false by default and people should opt in if they want that behavior.
```python
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
constraint=self.gamma_constraint,
autocast=False,
```

```python
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
constraint=self.beta_constraint,
autocast=False,
```
```python
if self._epsilon_too_small_warning_issued:
    return
if not self.autocast:
    return
compute_dtype = backend.standardize_dtype(self.compute_dtype)
if compute_dtype not in ("float16", "bfloat16"):
    return
try:
    finfo = ml_dtypes.finfo(compute_dtype)
    min_pos = getattr(finfo, "smallest_subnormal", finfo.tiny)
except Exception:
    return
if self.epsilon != 0 and self.epsilon < float(min_pos):
    warnings.warn(
        "The configured `epsilon` is smaller than what can be "
        f"represented in the layer compute dtype ({compute_dtype}); "
        "it may be rounded to 0 under autocast. Consider increasing "
        "`epsilon` or setting `autocast=False` for this layer.",
        stacklevel=3,
    )
    self._epsilon_too_small_warning_issued = True
```
Move this into `__init__` and remove `self._epsilon_too_small_warning_issued`.
All of this will work after `super().__init__()` is called.
```python
if compute_dtype not in ("float16", "bfloat16"):
    return
```
Remove this. The test on finfo below should apply if you're using float64 vs float32 for instance.
```python
    return
try:
    finfo = ml_dtypes.finfo(compute_dtype)
    min_pos = getattr(finfo, "smallest_subnormal", finfo.tiny)
```
`smallest_subnormal` and `tiny` are way too small; they would never catch the issue in the bug.
I think you should use `finfo.eps` instead.
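For reference, the three finfo thresholds differ by orders of magnitude. This sketch uses `numpy.finfo`, which reports the same float16 constants as `ml_dtypes.finfo`:

```python
import numpy as np

f16 = np.finfo(np.float16)
print(float(f16.eps))                 # ~9.77e-4: spacing between 1.0 and the next float16
print(float(f16.tiny))                # ~6.10e-5: smallest positive normal value
print(float(f16.smallest_subnormal))  # ~5.96e-8: smallest positive subnormal value

# An epsilon of 1e-4 is representable in float16 (it is above `tiny`),
# so the subnormal check passes, yet adding it to a variance near 1.0
# changes nothing because it is below `eps`:
print(np.float16(1.0) + np.float16(1e-4) == np.float16(1.0))  # True
```

This is why the subnormal-based check would never fire for the epsilon values that actually cause the bug, while an `eps`-based check would.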
```python
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")
    layer = layers.GroupNormalization(
        groups=2,
        axis=-1,
        scale=False,
        center=False,
        epsilon=1e-12,
        dtype="mixed_float16",
        autocast=False,
    )
    _ = layer(x)
self.assertFalse(any("epsilon" in str(x.message) for x in w))
```
What's the difference between this test and the previous one line 212?
```python
epsilon=1e-12,
dtype="mixed_float16",
```
This combination should work after you use `finfo.eps`:

```python
epsilon=1e-4,
dtype="mixed_bfloat16",
```
```python
def test_large_value_within_autocast_scope(self):
    layer = layers.GroupNormalization(groups=2)
    layer.build((1, 4, 4, 4))
    large_value = ops.full(layer.gamma.shape, 70000)
    with backend.AutocastScope("float16"):
        layer.gamma.assign(large_value)
    self.assertAllClose(layer.gamma.value, large_value)

def test_mixed_float16_large_inputs(self):
    layer = layers.GroupNormalization(
        groups=2, axis=-1, scale=False, center=False
    )
    x = np.full((1, 4, 4), 70000.0, dtype="float32")
    with backend.AutocastScope("float16"):
        output = layer(x)
    output = backend.convert_to_numpy(output)
    self.assertFalse(np.any(np.isnan(output)))

def test_autocast_is_user_controllable(self):
    layer_default = layers.GroupNormalization()
    self.assertTrue(layer_default.autocast)

    layer_no_autocast = layers.GroupNormalization(autocast=False)
    self.assertFalse(layer_no_autocast.autocast)
```
Remove these, I don't think these tests exercise anything new.
Root Cause
With autocast=True (the default), inputs were cast to float16 before reaching call(). Values exceeding float16's max (65504) overflowed to inf, causing NaN propagation through the normalization math. The existing internal float32 upcast couldn't recover the already-lost values.
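The root cause can be reproduced in a few lines of plain numpy, independent of Keras:

```python
import numpy as np

# Values above float16's max finite value (65504) overflow to inf when
# cast down; the normalization arithmetic then turns inf - inf into NaN.
x32 = np.full((4,), 70000.0, dtype=np.float32)
x16 = x32.astype(np.float16)       # what autocast effectively does
print(np.isinf(x16).all())         # True: 70000 overflowed to inf

with np.errstate(invalid="ignore"):
    centered = x16 - x16.mean()    # inf - inf -> nan
print(np.isnan(centered).all())    # True

# Upcasting after the overflow cannot recover the original values:
print(np.isnan(centered.astype(np.float32)).all())  # True
```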
Fix

- `self.autocast = False` keeps inputs in their original dtype (float32), preventing overflow
- `autocast=False` on the gamma/beta weights stores the weights in float32 for precision
- `ops.cast(outputs, self.compute_dtype)` returns proper float16 output for mixed precision

I also added regression tests:

- `test_large_value_within_autocast_scope`: verifies weights aren't corrupted by autocast (the same test exists in BatchNormalization and LayerNormalization)
- `test_mixed_float16_large_inputs`: catches the actual NaN bug
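A minimal numpy sketch of this flow, assuming a simplified group reshape (`group_norm_sketch` and its layout are illustrative only, not the actual Keras implementation): do the normalization math in float32, then cast the result to the compute dtype at the end.

```python
import numpy as np

def group_norm_sketch(x, groups, epsilon=1e-3, compute_dtype=np.float16):
    b, c = x.shape[0], x.shape[-1]
    x32 = x.astype(np.float32)                 # keep the math in float32
    g = x32.reshape(b, -1, groups, c // groups)
    mean = g.mean(axis=(1, 3), keepdims=True)
    var = g.var(axis=(1, 3), keepdims=True)
    out = (g - mean) / np.sqrt(var + epsilon)
    return out.reshape(x.shape).astype(compute_dtype)  # cast at the end

# Large inputs that would overflow float16 now normalize cleanly:
x = np.full((1, 4, 4), 70000.0, dtype=np.float32)
y = group_norm_sketch(x, groups=2)
print(y.dtype)              # float16
print(np.isnan(y).any())    # False
```

The key ordering is the same as in the fix: no downcast happens before the mean/variance computation, and the single cast to `compute_dtype` happens only on the finished output.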